/**
 * @file KNN.cpp
 * @author enemy1205 (enemy1205@qq.com)
 * @brief KNN执行类
 * @date 2021-09-24
 */
#include "KNN.h"

/**
 * @brief Sample variance of one coordinate over a range of points.
 *
 * Computes the Bessel-corrected (divide by N-1) variance of `first[axis]`
 * for every element in the half-open range [l, r).
 *
 * @param l    iterator to the first element of the range
 * @param r    iterator one past the last element of the range
 * @param axis index of the coordinate (dimension) to measure; must be a valid
 *             index into every element's coordinate vector (not checked)
 * @return variance along `axis`, or 0.0 when the range holds fewer than two
 *         elements (sample variance is undefined there)
 */
double Math::getVariance(const vector<pair<vector<double>, uint> >::iterator &l,
                         const vector<pair<vector<double>, uint> >::iterator &r,
                         const uint &axis) {
    // Signed difference_type: an empty (or inverted) range must not wrap the
    // way an unsigned `count - 1` would.
    const auto count = r - l;

    // With 0 elements the mean is 0/0, and with 1 element the N-1 divisor is
    // 0. Both would yield NaN, so report "no spread" instead.
    if (count < 2) {
        return 0.0;
    }

    // Mean of the axis-th coordinate.
    double mean = 0.0;
    for (auto it = l; it != r; ++it) {
        mean += it->first[axis];
    }
    mean /= static_cast<double>(count);

    // Sum of squared deviations from the mean.
    double sq_diff = 0.0;
    for (auto it = l; it != r; ++it) {
        const double delta = it->first[axis] - mean;
        sq_diff += delta * delta;
    }
    return sq_diff / static_cast<double>(count - 1);
}

/**
 * @brief Squared Euclidean distance between two points.
 *
 * Sums (P1[i] - P2[i])^2 over all coordinates; no square root is taken.
 * NOTE(review): despite the name, the value is the *squared* distance. If this
 * is what KdTree reports as the neighbour distance, the printed distances are
 * squared as well. Confirm this is intended before changing it.
 *
 * @param P1 first point
 * @param P2 second point; must have the same dimension as P1 (asserted)
 * @return sum of squared coordinate differences
 */
double Math::getDistance(const vector<double> &P1, const vector<double> &P2) {
    assert(P1.size() == P2.size());
    double total = 0;
    auto lhs = P1.cbegin();
    auto rhs = P2.cbegin();
    for (; lhs != P1.cend(); ++lhs, ++rhs) {
        const double diff = *lhs - *rhs;
        total += diff * diff;
    }
    return total;
}

/**
 * @brief 给定测试案例
 */
void KNN::test_knn() {
    vector<vector<double> > data = {{2, 3},
                                    {5, 4},
                                    {9, 6},
                                    {4, 7},
                                    {8, 1},
                                    {7, 2}};
    KdTree T(data);
    auto neighbors = T.findKneighbor(vector<double>({7, 4}), 4);
    for (auto it = neighbors.begin(); it < neighbors.end(); ++it) {
        cout << "(";
        for (auto it2 = data[(*it).first].begin(); it2 < data[(*it).first].end(); ++it2) {
            if (data[(*it).first].end() - it2 == 1) {
                cout << (*it2) << ") ";
            } else {
                cout << (*it2) << ",";
            }
        }
        cout << (*it).second;
        cout << endl;
    }
}

/**
 * @brief Runs a neighbour query on caller-supplied data and prints the result.
 *
 * Builds a KdTree from `data`, asks it for `k` neighbours of `P`, and prints
 * one line per returned neighbour in the form "(x1,x2,...,xn) dis". The
 * coordinates come from data[neighbour.first], and `dis` is neighbour.second.
 *
 * @param data dataset the KdTree is built from
 * @param P    query point
 * @param k    number of neighbours to search for (forwarded to findKneighbor)
 */
void KNN::run(vector<vector<double>> &data, vector<double> &P, const int &k) {
    KdTree tree(data);
    const auto result = tree.findKneighbor(P, k);
    for (const auto &entry : result) {
        const auto &coords = data[entry.first];
        cout << "(";
        bool leading = true;
        for (const double value : coords) {
            // Separator only between coordinates, never before the first.
            if (!leading) {
                cout << ",";
            }
            cout << value;
            leading = false;
        }
        cout << ") " << entry.second << endl;
    }
}